import pymysql
import time

'''
    基于pymysql的数据库扩展类,操作mysql数据库和PHP有相同感,PHP中使用关联数组处理,python使用dict处理
'''


class PDO(object):
    """Thin convenience wrapper around a pymysql connection.

    Rows are returned as dicts (``DictCursor``) and simple INSERT / UPDATE /
    DELETE / SELECT statements are built from dicts, in the spirit of PHP's
    associative-array database helpers.

    Values are embedded into the SQL text as escaped literals (statements are
    not sent as parameterized queries). Escaping follows MySQL's default
    backslash-escape rules; NOTE(review): if the server runs with
    NO_BACKSLASH_ESCAPES the escaping in ``_ESCAPE_TABLE`` is not correct.

    Database errors are re-raised as ``Exception("[Err] <code>-<message>")``
    with the original exception chained as ``__cause__``.
    """

    __db = None  # connection object returned by pymysql.connect
    __cursor = None  # cursor of __db; DictCursor, so rows come back as dicts
    __error_message = ""  # message of the last error
    __error_code = ""  # code of the last error
    __error = ""  # formatted "[Err] <code>-<message>" of the last error
    __sql = ""  # last SQL statement built/executed
    port = 3306  # default port
    __debug = ""  # when truthy, __del__ prints debug information
    __start_time = ""  # time.time() taken right after connecting (debug mode only)
    __end_time = ""  # time.time() taken in __del__ (debug mode only)
    __affected_rows = 0  # affected rows of the last statement
    __timeout_error_value = 1000  # elapsed ms above which debug output shows the time in red
    __last_id = None  # id returned by insert_id() after the last insert
    MAX_LIMIT = 100  # upper bound for the LIMIT used by select()

    # Translation table for MySQL string literals; same mapping as
    # pymysql.converters.escape_string (backslash-escape mode).
    _ESCAPE_TABLE = str.maketrans({
        "\0": "\\0",
        "\\": "\\\\",
        "\n": "\\n",
        "\r": "\\r",
        "\032": "\\Z",
        '"': '\\"',
        "'": "\\'",
    })

    def __init__(self, host, user, password, db, charset="UTF8", port=3306, debug=False, autocommit=True):
        """Open the connection and create a dict cursor.

        :param host: MySQL server host.
        :param user: user name.
        :param password: password.
        :param db: database (schema) name.
        :param charset: connection character set, ``"UTF8"`` by default.
        :param port: server port, 3306 by default.
        :param debug: when true, ``__del__`` prints the last SQL, affected rows,
            last insert id, elapsed time and last error.
        :param autocommit: autocommit mode of the connection (on by default).
        :raises Exception: ``"[Err] <code>-<message>"`` if connecting fails.
        """
        self.__debug = debug
        self.port = port
        try:
            config = {
                "host": host,
                "user": user,
                "password": password,
                "db": db,
                "charset": charset,
                "port": port,
                "cursorclass": pymysql.cursors.DictCursor,
            }
            self.__db = pymysql.connect(**config)
            self.__cursor = self.__db.cursor()
            self.__db.autocommit(autocommit)

            if self.__debug:
                self.__start_time = time.time()

        except Exception as e:
            self.__exception(e)

    def close(self):
        """Close the underlying connection; calling it again is a no-op."""
        if self.__db is not None and getattr(self.__db, "open", False):
            self.__db.close()

    def __enter__(self):
        """Support ``with PDO(...) as db:``; the connection is closed on exit."""
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        self.close()
        return False  # never swallow the exception raised inside the block

    def __exception(self, e):
        """Record ``e`` as the last error and re-raise it as a plain Exception.

        Exceptions whose ``args`` are ``(code, message)`` keep that code; any
        other exception (e.g. AttributeError, TypeError) uses its class name
        as the code instead of crashing with an IndexError.

        :raises Exception: always, with message ``"[Err] <code>-<message>"``
            and ``e`` chained as ``__cause__``.
        """
        if len(e.args) >= 2:
            self.__error_code = str(e.args[0])
            self.__error_message = str(e.args[1])
        else:
            self.__error_code = type(e).__name__
            self.__error_message = str(e)
        self.__error = "[Err] " + self.__error_code + "-" + self.__error_message
        raise Exception(self.__error) from e

    def get_connect(self):
        """Return the raw pymysql connection object."""
        return self.__db

    def get_cursor(self):
        """Return the cursor created from the connection (``connect.cursor()``)."""
        return self.__cursor

    def insert(self, table, rows):
        """Insert one row (a dict) or several rows (a list of dicts) into ``table``.

        Columns are taken from the first row; every row must have the same keys
        (in any order).

        :return: the connection's ``insert_id()`` after the statement.
        :raises ValueError: if ``rows`` is empty or the rows' keys differ.
        :raises Exception: ``"[Err] <code>-<message>"`` on database errors.
        """
        self.__sql = "INSERT INTO " + self.__quote_identifier(table) + " " + self.__insert_sql(rows)

        try:
            self.__cursor.execute(self.__sql)
            self.__affected_rows = self.__db.affected_rows()
            self.__last_id = self.__db.insert_id()
            return self.__last_id
        except Exception as e:
            self.__exception(e)

    def __insert_sql(self, rows):
        """Build ``(`c1`,`c2`) VALUES (...),(...)`` for a dict or list of dicts."""
        row_list = [rows] if isinstance(rows, dict) else list(rows)
        if not row_list:
            raise ValueError("insert() needs at least one row")

        columns = list(row_list[0])
        column_set = set(columns)
        values_sql = []
        for row in row_list:
            if set(row) != column_set:
                raise ValueError("every row passed to insert() must have the same columns")
            # Look values up by column name so rows whose keys are in a
            # different order still line up with the column list.
            values_sql.append("(" + ",".join(self.__literal(row[col]) for col in columns) + ")")

        return "(" + self.__get_columns(row_list[0]) + ") VALUES " + ",".join(values_sql)

    def __get_columns(self, row):
        """Return the backtick-quoted keys of ``row`` joined by commas."""
        return ",".join(self.__quote_identifier(key) for key in row)

    def delete(self, table, where, limit=1):
        """Delete at most ``limit`` rows of ``table`` matching ``where``.

        :param where: condition dict (see ``__where_sql``); must produce at
            least one condition, so an accidental ``{}`` cannot hit arbitrary rows.
        :return: number of affected rows.
        :raises ValueError: if ``where`` yields no condition.
        :raises Exception: ``"[Err] <code>-<message>"`` on database errors.
        """
        where_clause = self.__where_clause(where)
        if not where_clause:
            raise ValueError("delete() requires a non-empty where condition")
        self.__sql = ("DELETE FROM " + self.__quote_identifier(table)
                      + where_clause + " LIMIT " + str(int(limit)))

        try:
            self.__cursor.execute(self.__sql)
            self.__affected_rows = self.__db.affected_rows()
            return self.__affected_rows
        except Exception as e:
            self.__exception(e)

    def __where_sql(self, where):
        """Translate a condition dict into SQL (without the WHERE keyword).

        Supported keys:

        * ``{"col": value}``        -> ``col = 'value'`` (``col IS NULL`` for None)
        * ``{"col >": value}``      -> ``col > 'value'``; the operator text after
          the column name is inserted verbatim, so it must come from trusted code
        * ``{"col": [v1, v2]}``     -> ``col IN ('v1','v2')``
        * ``{"col NOT IN": [...]}`` -> ``col NOT IN (...)``
        * key ``"OR"`` present      -> conditions are joined with OR instead of
          AND (the value of ``"OR"`` is ignored)

        The caller's dict and lists are not modified.
        """
        if not where:
            return ""

        conditions = []
        for key, value in where.items():
            if key == "OR":
                continue  # marker key: only switches the glue to OR

            tokens = key.split()
            if not tokens:
                raise ValueError("where keys must start with a column name")
            column = self.__quote_identifier(tokens[0])
            operator = " ".join(tokens[1:])

            if isinstance(value, (list, tuple, set, frozenset)):
                conditions.append(self.__in_condition(column, operator or "IN", value))
            elif operator:
                conditions.append(column + " " + operator + " " + self.__literal(value))
            elif value is None:
                conditions.append(column + " IS NULL")
            else:
                conditions.append(column + " = " + self.__literal(value))

        glue = " OR " if "OR" in where else " AND "
        return glue.join(conditions)

    def __in_condition(self, column, operator, values):
        """Render ``column operator (v1,v2,...)``.

        An empty value list is rendered as a constant condition, because
        ``IN ()`` is a MySQL syntax error: always false for IN, always true
        when the operator contains NOT.
        """
        values = list(values)
        if not values:
            return "1 = 1" if "NOT" in operator.upper().split() else "1 = 0"
        return column + " " + operator + " (" + ",".join(self.__literal(v) for v in values) + ")"

    def __where_clause(self, where):
        """Return ``" WHERE <conditions>"``, or ``""`` when there are no conditions."""
        condition = self.__where_sql(where) if where else ""
        return " WHERE " + condition if condition else ""

    def update(self, table, update_dict, where, limit=1):
        """Update at most ``limit`` rows of ``table`` matching ``where``.

        :param update_dict: ``{column: new_value}``.
        :param where: condition dict (see ``__where_sql``); must produce at
            least one condition.
        :return: number of affected rows.
        :raises ValueError: if ``where`` yields no condition.
        :raises Exception: ``"[Err] <code>-<message>"`` on database errors.
        """
        where_clause = self.__where_clause(where)
        if not where_clause:
            raise ValueError("update() requires a non-empty where condition")
        self.__sql = ("UPDATE " + self.__quote_identifier(table) + " SET "
                      + self.__update_sql(update_dict) + where_clause
                      + " LIMIT " + str(int(limit)))

        try:
            self.__cursor.execute(self.__sql)
            self.__affected_rows = self.__db.affected_rows()
            return self.__affected_rows
        except Exception as e:
            self.__exception(e)

    def __update_sql(self, update_dict):
        """Build the ``col = value`` list of a SET clause; ``""`` for a non-dict."""
        if not isinstance(update_dict, dict):
            return ""
        return ",".join(
            self.__quote_identifier(col) + " = " + self.__literal(value)
            for col, value in update_dict.items()
        )

    def select(self, cols, table, where=None, order=None, offset=0, limit=100):
        """Run a SELECT on ``table`` and return all fetched rows (dicts).

        :param cols: list/tuple of column names (backtick-quoted), a raw SQL
            column string used verbatim (e.g. ``"COUNT(*) AS n"``), or anything
            else (e.g. ``[]``) for ``*``. The caller's list is not modified.
        :param where: condition dict (see ``__where_sql``); None/``{}`` for none.
        :param order: ``{column: "ASC" | "DESC"}``, applied in dict order.
        :param offset: number of rows to skip.
        :param limit: maximum number of rows, capped at ``MAX_LIMIT``.
        :raises ValueError: for a sort direction other than ASC/DESC.
        :raises Exception: ``"[Err] <code>-<message>"`` on database errors.
        """
        if isinstance(cols, (list, tuple)) and cols:
            need_column = ",".join(self.__quote_identifier(col) for col in cols)
        elif isinstance(cols, str) and cols:
            need_column = cols
        else:
            need_column = "*"

        order_sql = ""
        if isinstance(order, dict) and order:
            order_sql = " ORDER BY " + ",".join(
                self.__quote_identifier(col) + " " + self.__sort_direction(sort)
                for col, sort in order.items()
            )

        limit = min(int(limit), self.MAX_LIMIT)
        self.__sql = ("SELECT " + need_column + " FROM " + self.__quote_identifier(table)
                      + self.__where_clause(where) + order_sql
                      + " LIMIT " + str(int(offset)) + "," + str(limit))

        try:
            self.__cursor.execute(self.__sql)
            self.__affected_rows = self.__db.affected_rows()
            return self.__cursor.fetchall()
        except Exception as e:
            self.__exception(e)

    @staticmethod
    def __sort_direction(sort):
        """Normalize an ORDER BY direction; only ASC/DESC are accepted."""
        direction = str(sort).strip().upper()
        if direction not in ("ASC", "DESC"):
            raise ValueError("sort direction must be 'ASC' or 'DESC', got %r" % (sort,))
        return direction

    def query(self, sql):
        """Execute a raw SQL statement (intended for SELECT) and return all rows.

        The SQL is executed as given; callers are responsible for escaping.
        """
        try:
            self.__sql = sql
            self.__cursor.execute(self.__sql)
            self.__affected_rows = self.__db.affected_rows()
            return self.__cursor.fetchall()
        except Exception as e:
            self.__exception(e)

    def execute(self, sql, ret_last_id=False):
        """Execute a raw INSERT/UPDATE/DELETE statement.

        The SQL is executed as given; callers are responsible for escaping.

        :param ret_last_id: return ``insert_id()`` instead of the affected rows.
        :return: affected rows, or the last insert id when ``ret_last_id``.
        """
        try:
            self.__sql = sql
            self.__cursor.execute(self.__sql)
            self.__affected_rows = self.__db.affected_rows()

            if ret_last_id:
                self.__last_id = self.__db.insert_id()
                return self.__last_id

            return self.__affected_rows
        except Exception as e:
            self.__exception(e)

    def count(self, table, where=None):
        """Return ``COUNT(*)`` of ``table`` rows matching ``where`` (all rows if empty)."""
        self.__sql = "SELECT COUNT(*) as num FROM " + self.__quote_identifier(table) + self.__where_clause(where)

        try:
            self.__cursor.execute(self.__sql)
            row = self.__cursor.fetchone()
            return row["num"]
        except Exception as e:
            self.__exception(e)

    def get_sql(self):
        """Return the SQL statement of the most recent operation."""
        return self.__sql

    @staticmethod
    def __quote_identifier(name):
        """Backtick-quote a table/column name, doubling embedded backticks."""
        return "`" + str(name).replace("`", "``") + "`"

    def __literal(self, value):
        """Render ``value`` as a SQL literal.

        None -> NULL, bool -> 1/0; everything else is converted with ``str``,
        escaped and single-quoted (numbers are therefore quoted as before).
        """
        if value is None:
            return "NULL"
        if isinstance(value, bool):
            return "1" if value else "0"
        return "'" + str(value).translate(self._ESCAPE_TABLE) + "'"

    def __del__(self):
        """In debug mode, print the last SQL, affected rows, last insert id,
        time elapsed since connecting (ms) and the last error.

        Output is green; the time turns red at or above
        ``__timeout_error_value`` ms and the error is bold red.
        """
        if not self.__debug:
            return

        green, red, bold_red, reset = "\033[0;32m", "\033[0;31m", "\033[1;31m", "\033[0m"

        print(green + "[SQL] " + self.__sql + reset)  # SQL statement
        print(green + "[affected_rows] " + str(self.__affected_rows) + reset)  # affected rows
        if self.__last_id:
            print(green + "[last_insert_id] " + str(self.__last_id) + reset)  # last inserted id

        # __start_time is only set after a successful connect; without it there
        # is no elapsed time to report (and subtracting "" would raise).
        if isinstance(self.__start_time, float):
            self.__end_time = time.time()
            use_time = int((self.__end_time - self.__start_time) * 1000)
            color = green if use_time < self.__timeout_error_value else red
            print(color + "[time] " + str(use_time) + " ms" + reset)

        if self.__error:
            print(bold_red + self.__error + reset)